{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vkdnLiKk71g-"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "0asMuNro71hA"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iPFgLeZIsZ3Q"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/federated/tutorials/federated_select\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/federated/blob/master/docs/tutorials/federated_select.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/federated/blob/master/docs/tutorials/federated_select.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/federated/docs/tutorials/federated_select.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T94owwmP-41H"
      },
      "source": [
        "# Sending Different Data To Particular Clients With tff.federated_select"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2K2GBCD2G6P8"
      },
      "source": [
        "This tutorial demonstrates how to implement custom federated algorithms in TFF that require sending different data to different clients. You may already be familiar with `tff.federated_broadcast`, which sends a single server-placed value to all clients. This tutorial focuses on cases where different parts of a server-placed value are sent to different clients. This can be useful for dividing a model across clients so that no single client has to receive the whole model.\n",
        "\n",
        "Let's get started by importing both `tensorflow` and `tensorflow_federated`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9LcC1AwjoqfR"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install --quiet --upgrade tensorflow-federated-nightly\n",
        "!pip install --quiet --upgrade nest-asyncio\n",
        "\n",
        "import nest_asyncio\n",
        "nest_asyncio.apply()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YVyimqc7qHCn"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import tensorflow_federated as tff"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v35NnHqL_Zci"
      },
      "source": [
        "## Sending Different Values Based On Client Data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S169M4-qH9Y9"
      },
      "source": [
        "Consider the case where we have a server-placed list and want to send a few of its elements to each client, chosen according to some client-placed data. For example, the server holds a list of strings, and each client holds a comma-separated string of the indices it wants to download. We can implement that as follows:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Rc_XhL7h_vQC"
      },
      "outputs": [],
      "source": [
        "list_of_strings_type = tff.TensorType(tf.string, [None])\n",
        "# We only ever send exactly two values to each client. The number of keys per\n",
        "# client must be a fixed number across all clients.\n",
        "number_of_keys_per_client = 2\n",
        "keys_type = tff.TensorType(tf.int32, [number_of_keys_per_client])\n",
        "get_size = tff.tf_computation(lambda x: tf.size(x))\n",
        "select_fn = tff.tf_computation(lambda val, index: tf.gather(val, index))\n",
        "client_data_type = tf.string\n",
        "\n",
        "# A function from our client data to the indices of the values we'd like to\n",
        "# select from the server.\n",
        "@tff.tf_computation(client_data_type)\n",
        "@tff.check_returns_type(keys_type)\n",
        "def keys_for_client(client_string):\n",
        "  # We assume our client data is a single string consisting of exactly two\n",
        "  # comma-separated integers indicating which values to grab from the server.\n",
        "  split = tf.strings.split([client_string], sep=',')[0]\n",
        "  return tf.strings.to_number([split[0], split[1]], tf.int32)\n",
        "\n",
        "@tff.tf_computation(tff.SequenceType(tf.string))\n",
        "@tff.check_returns_type(tf.string)\n",
        "def concatenate(values):\n",
        "  def reduce_fn(acc, item):\n",
        "    return tf.cond(tf.math.equal(acc, ''),\n",
        "                   lambda: item,\n",
        "                   lambda: tf.strings.join([acc, item], ','))\n",
        "  return values.reduce('', reduce_fn)\n",
        "\n",
        "@tff.federated_computation(tff.type_at_server(list_of_strings_type), tff.type_at_clients(client_data_type))\n",
        "def broadcast_based_on_client_data(list_of_strings_at_server, client_data):\n",
        "  keys_at_clients = tff.federated_map(keys_for_client, client_data)\n",
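        "  # `max_key` must be greater than every client key; the list size works\n",
        "  # because valid indices run from 0 to size - 1.\n",
        "  max_key = tff.federated_map(get_size, list_of_strings_at_server)\n",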
        "  values_at_clients = tff.federated_select(keys_at_clients, max_key, list_of_strings_at_server, select_fn)\n",
        "  value_at_clients = tff.federated_map(concatenate, values_at_clients)\n",
        "  return value_at_clients"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QpdKyL77JKea"
      },
      "source": [
        "Then we can simulate our computation by providing the server-placed list of strings as well as string data for each client:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aneU54u0F6al"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[<tf.Tensor: shape=(), dtype=string, numpy=b'a,b'>,\n",
              " <tf.Tensor: shape=(), dtype=string, numpy=b'b,c'>,\n",
              " <tf.Tensor: shape=(), dtype=string, numpy=b'c,a'>]"
            ]
          },
          "execution_count": 49,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "client_data = ['0,1', '1,2', '2,0']\n",
        "broadcast_based_on_client_data(['a', 'b', 'c'], client_data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TeLPCh8z_BJJ"
      },
      "source": [
        "## Sending A Randomized Element To Each Client"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ADjD0poWJkIj"
      },
      "source": [
        "Alternatively, it may be useful to send a random portion of the server data to each client. We can implement that by first generating a random key on each client and then following a similar selection process to the one used above:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "texCnO6Erds4"
      },
      "outputs": [],
      "source": [
        "@tff.tf_computation(tf.int32)\n",
        "@tff.check_returns_type(tff.TensorType(tf.int32, [1]))\n",
        "def get_random_key(max_key):\n",
        "  # `maxval` is exclusive, so the sampled key always lies in [0, max_key).\n",
        "  return tf.random.uniform(shape=[1], minval=0, maxval=max_key, dtype=tf.int32)\n",
        "\n",
        "list_of_strings_type = tff.TensorType(tf.string, [None])\n",
        "get_size = tff.tf_computation(lambda x: tf.size(x))\n",
        "select_fn = tff.tf_computation(lambda val, index: tf.gather(val, index))\n",
        "\n",
        "@tff.tf_computation(tff.SequenceType(tf.string))\n",
        "@tff.check_returns_type(tf.string)\n",
        "def get_last_element(sequence):\n",
        "  return sequence.reduce('', lambda _initial_state, val: val)\n",
        "\n",
        "@tff.federated_computation(tff.type_at_server(list_of_strings_type))\n",
        "def broadcast_random_element(list_of_strings_at_server):\n",
        "  max_key_at_server = tff.federated_map(get_size, list_of_strings_at_server)\n",
        "  max_key_at_clients = tff.federated_broadcast(max_key_at_server)\n",
        "  key_at_clients = tff.federated_map(get_random_key, max_key_at_clients)\n",
        "  random_string_sequence_at_clients = tff.federated_select(\n",
        "      key_at_clients, max_key_at_server, list_of_strings_at_server, select_fn)\n",
        "  # Even though we only passed in a single key, `federated_select` returns a\n",
        "  # sequence for each client. We only care about the last (and only) element.\n",
        "  random_string_at_clients = tff.federated_map(get_last_element, random_string_sequence_at_clients)\n",
        "  return random_string_at_clients"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eCgbnWznJxVq"
      },
      "source": [
        "Because `broadcast_random_element` takes no client-placed data, the number of clients cannot be inferred from its arguments, so we have to configure the TFF simulation runtime with a default number of clients to use:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N70yh3i6vYoy"
      },
      "outputs": [],
      "source": [
        "tff.backends.native.set_local_execution_context(default_num_clients=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TF1OttS2J9b4"
      },
      "source": [
        "Then we can simulate the selection. We can change `default_num_clients` above and the list of strings below to generate different results, or simply re-run the computation to generate different random outputs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lowrkwE09mIe"
      },
      "outputs": [],
      "source": [
        "broadcast_random_element(tf.convert_to_tensor(['foo', 'bar', 'baz']))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "federated_select.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
